/* Copyright (c) 1996-2022 The OPC Foundation. All rights reserved.
   The source code in this file is covered under a dual-license scenario:
     - RCL: for OPC Foundation Corporate Members in good-standing
     - GPL V2: everybody else
   RCL license terms accompanied with this source code. See http://opcfoundation.org/License/RCL/1.00/
   GNU General Public License as published by the Free Software Foundation;
   version 2 of the License are accompanied with this source code. See http://opcfoundation.org/License/GPLv2
   This source code is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
*/

using System;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Net.Security;
using System.Reflection;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
#if NETSTANDARD2_1 || NET472_OR_GREATER || NET5_0_OR_GREATER
using System.Security.Cryptography;
#endif

namespace Opc.Ua.Bindings
{
    /// <summary>
    /// Factory for <see cref="HttpsTransportChannel"/> instances that are bound to
    /// the uri scheme <see cref="Utils.UriSchemeHttps"/>.
    /// </summary>
    public class HttpsTransportChannelFactory : ITransportChannelFactory
    {
        /// <summary>
        /// Gets the uri scheme handled by the channels this factory creates.
        /// </summary>
        public string UriScheme
        {
            get { return Utils.UriSchemeHttps; }
        }

        /// <summary>
        /// Builds a new Https transport channel for <see cref="UriScheme"/>.
        /// The channel still has to be initialized and opened by the caller.
        /// </summary>
        /// <returns>The newly created transport channel.</returns>
        public ITransportChannel Create() => new HttpsTransportChannel(UriScheme);
    }

    /// <summary>
    /// Factory for <see cref="HttpsTransportChannel"/> instances that are bound to
    /// the uri scheme <see cref="Utils.UriSchemeOpcHttps"/> (opc.https).
    /// </summary>
    public class OpcHttpsTransportChannelFactory : ITransportChannelFactory
    {
        /// <summary>
        /// Gets the uri scheme handled by the channels this factory creates.
        /// </summary>
        public string UriScheme
        {
            get { return Utils.UriSchemeOpcHttps; }
        }

        /// <summary>
        /// Builds a new Https transport channel for <see cref="UriScheme"/>.
        /// The channel still has to be initialized and opened by the caller.
        /// </summary>
        /// <returns>The newly created transport channel.</returns>
        public ITransportChannel Create() => new HttpsTransportChannel(UriScheme);
    }

    /// <summary>
    /// An ITransportChannel implementation that posts binary encoded service
    /// requests to a server through an HttpClient.
    /// </summary>
    public class HttpsTransportChannel : ITransportChannel
    {
        /// <summary>
        /// Limit the number of concurrent service requests on the server.
        /// The value is written to the <c>MaxConnectionsPerServer</c> property of the
        /// http client handler in <see cref="Open"/>, if the handler has such a property.
        /// </summary>
        private const int kMaxConnectionsPerServer = 64;

        /// <summary>
        /// Initializes a channel that reports the given uri scheme.
        /// </summary>
        /// <param name="uriScheme">The uri scheme exposed through <see cref="UriScheme"/>.</param>
        public HttpsTransportChannel(string uriScheme) => UriScheme = uriScheme;

        /// <inheritdoc/>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases the http client owned by the channel.
        /// Override this method if you need to release additional resources and
        /// call the base implementation from the override.
        /// </summary>
        /// <param name="disposing">
        /// <c>true</c> when called from <see cref="Dispose()"/>; the http client is
        /// only disposed in that case.
        /// </param>
        protected virtual void Dispose(bool disposing)
        {
            if (disposing)
            {
                // The client is created in Open() and otherwise only released in Close();
                // without this a channel that is disposed but never closed leaks the
                // client together with its handler. Disposing it a second time is harmless.
                m_client?.Dispose();
            }
        }

        /// <inheritdoc/>
        /// <remarks>Assigned once by the constructor.</remarks>
        public string UriScheme { get; }

        /// <inheritdoc/>
        /// <remarks>
        /// Advertises Open, Reconnect, BeginSendRequest and SendRequestAsync.
        /// </remarks>
        public TransportChannelFeatures SupportedFeatures
        {
            get
            {
                return TransportChannelFeatures.Open |
                    TransportChannelFeatures.Reconnect |
                    TransportChannelFeatures.BeginSendRequest |
                    TransportChannelFeatures.SendRequestAsync;
            }
        }

        /// <inheritdoc/>
        /// <remarks>Read from the settings stored by Initialize.</remarks>
        public EndpointDescription EndpointDescription
        {
            get { return m_settings.Description; }
        }

        /// <inheritdoc/>
        /// <remarks>Read from the settings stored by Initialize.</remarks>
        public EndpointConfiguration EndpointConfiguration
        {
            get { return m_settings.Configuration; }
        }

        /// <inheritdoc/>
        /// <remarks>Read from the quotas built when the settings were stored.</remarks>
        public IServiceMessageContext MessageContext
        {
            get { return m_quotas.MessageContext; }
        }

        /// <inheritdoc/>
        /// <remarks>Always <c>null</c> for this channel.</remarks>
        public ChannelToken CurrentToken
        {
            get { return null; }
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Both accessors are empty: handlers can be attached and detached,
        /// but they are not stored by this channel.
        /// </remarks>
        public event ChannelTokenActivatedEventHandler OnTokenActivated
        {
            add
            {
                // intentionally empty, the handler is not stored
            }
            remove
            {
                // intentionally empty, nothing was stored by add
            }
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Used as the millisecond delay of the cancellation token source that
        /// bounds each http request sent by this channel.
        /// </remarks>
        public int OperationTimeout { get; set; }

        /// <inheritdoc/>
        /// <remarks>Only stores the url and the settings; nothing is sent.</remarks>
        public void Initialize(Uri url, TransportChannelSettings settings)
            => SaveSettings(url, settings);

        /// <summary>
        /// Initializes the channel from a waiting reverse connection.
        /// Only the endpoint url of the connection and the settings are stored.
        /// </summary>
        /// <param name="connection">The connection whose endpoint url is used.</param>
        /// <param name="settings">The settings to use when creating the channel.</param>
        public void Initialize(
            ITransportWaitingConnection connection,
            TransportChannelSettings settings)
            => SaveSettings(connection.EndpointUrl, settings);

        /// <inheritdoc/>
        /// <remarks>
        /// Configures a <see cref="HttpClientHandler"/> (connection limit, optional client
        /// certificate, server certificate validation) and creates the
        /// <see cref="HttpClient"/> that is used for all requests. Handler properties which
        /// may be missing on a platform are looked up by reflection and skipped if absent.
        /// Any exception is logged and rethrown; the handler is disposed in that case.
        /// </remarks>
        public void Open()
        {
            HttpClientHandler handler = null;
            try
            {
                Utils.LogInfo("{0} Open {1}.", nameof(HttpsTransportChannel), m_url);

                // auto validate server cert, if supported
                // if unsupported, the TLS server cert must be trusted by a root CA
                // Note: the properties are assigned one by one instead of using an object
                // initializer, so the handler is already tracked (and disposed by the catch
                // block below) when one of the assignments throws.
                handler = new HttpClientHandler();
                handler.ClientCertificateOptions = ClientCertificateOption.Manual;
                handler.AllowAutoRedirect = false;
                handler.MaxRequestContentBufferSize = m_quotas.MaxMessageSize;

                // limit the number of concurrent connections, if supported
                PropertyInfo maxConnectionsPerServerProperty = handler.GetType()
                    .GetProperty("MaxConnectionsPerServer");
                maxConnectionsPerServerProperty?.SetValue(handler, kMaxConnectionsPerServer);

                // send client certificate for servers that require TLS client authentication
                if (m_settings.ClientCertificate != null)
                {
                    // prepare the client TLS certificate
                    X509Certificate2 clientCertificate = m_settings.ClientCertificate;
#if NETSTANDARD2_1 || NET472_OR_GREATER || NET5_0_OR_GREATER
                    try
                    {
                        // Create a copy of the certificate with the private key on platforms
                        // which default to the ephemeral KeySet. Also a new certificate must be reloaded.
                        // If the key fails to copy, its probably a non exportable key from the X509Store.
                        // Then we can use the original certificate, the private key is already in the key store.
                        clientCertificate = X509Utils.CreateCopyWithPrivateKey(
                            m_settings.ClientCertificate,
                            false);
                    }
                    catch (CryptographicException ce)
                    {
                        Utils.LogTrace(
                            "Copy of the private key for https was denied: {0}",
                            ce.Message);
                    }
#endif
                    PropertyInfo certProperty = handler.GetType().GetProperty("ClientCertificates");
                    if (certProperty != null)
                    {
                        var clientCertificates = (X509CertificateCollection)certProperty.GetValue(
                            handler);
                        _ = clientCertificates?.Add(clientCertificate);
                    }
                }

                PropertyInfo propertyInfo = handler.GetType()
                    .GetProperty("ServerCertificateCustomValidationCallback");
                if (propertyInfo != null)
                {
                    try
                    {
                        Func<
                            HttpRequestMessage,
                            X509Certificate2,
                            X509Chain,
                            SslPolicyErrors,
                            bool
                        > serverCertificateCustomValidationCallback = (_, cert, chain, _) =>
                        {
                            try
                            {
                                // validate the whole chain if the TLS stack provides one,
                                // otherwise only the server certificate itself
                                var validationChain = new X509Certificate2Collection();
                                if (chain != null && chain.ChainElements != null)
                                {
                                    int i = 0;
                                    Utils.LogInfo(
                                        Utils.TraceMasks.Security,
                                        "{0} Validate server chain:",
                                        nameof(HttpsTransportChannel));
                                    foreach (X509ChainElement element in chain.ChainElements)
                                    {
                                        Utils.LogCertificate(
                                            Utils.TraceMasks.Security,
                                            "{0}: ",
                                            element.Certificate,
                                            i);
                                        validationChain.Add(element.Certificate);
                                        i++;
                                    }
                                }
                                else
                                {
                                    Utils.LogCertificate(
                                        Utils.TraceMasks.Security,
                                        "{0} Validate Server Certificate: ",
                                        cert,
                                        nameof(HttpsTransportChannel));
                                    validationChain.Add(cert);
                                }

                                // NOTE(review): when no certificate validator is configured the
                                // certificate is accepted without any check and the SSL policy
                                // errors are ignored; confirm that this is intended.
                                m_quotas.CertificateValidator?.Validate(validationChain);

                                return true;
                            }
                            catch (Exception ex)
                            {
                                Utils.LogError(
                                    ex,
                                    "{0} Failed to validate certificate.",
                                    nameof(HttpsTransportChannel));
                            }
                            return false;
                        };
                        propertyInfo.SetValue(handler, serverCertificateCustomValidationCallback);

                        Utils.LogInfo(
                            "{0} ServerCertificate callback enabled.",
                            nameof(HttpsTransportChannel));
                    }
                    catch (PlatformNotSupportedException)
                    {
                        // client may throw if not supported (e.g. UWP)
                    }
                    catch (TargetInvocationException tie)
                        when (tie.InnerException is PlatformNotSupportedException)
                    {
                        // the setter is invoked by reflection, which wraps the exception
                        // thrown by the property in a TargetInvocationException
                    }
                }

                m_client = new HttpClient(handler);
            }
            catch (Exception ex)
            {
                // the client never took ownership of the handler, release it here
                handler?.Dispose();
                Utils.LogError(ex, "Exception creating HTTPS Client.");
                throw;
            }
        }

        /// <inheritdoc/>
        /// <remarks>Logs the close and disposes the http client, if one exists.</remarks>
        public void Close()
        {
            Utils.LogInfo("{0} Close {1}.", nameof(HttpsTransportChannel), m_url);

            HttpClient client = m_client;
            if (client != null)
            {
                client.Dispose();
            }
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Runs <see cref="Close"/> synchronously and returns a completed task;
        /// the cancellation token is not observed.
        /// </remarks>
        public Task CloseAsync(CancellationToken ct)
        {
            this.Close();
            return Task.CompletedTask;
        }

        /// <summary>
        /// Carries the state of a single Https request between
        /// BeginSendRequest and EndSendRequest.
        /// </summary>
        private class HttpsAsyncResult : AsyncResultBase
        {
            /// <summary>
            /// Creates the result for a request.
            /// </summary>
            /// <param name="callback">Forwarded to the base class.</param>
            /// <param name="callbackData">Forwarded to the base class.</param>
            /// <param name="timeout">Forwarded to the base class.</param>
            /// <param name="request">The service request that is sent.</param>
            /// <param name="response">The http response, or null if not yet available.</param>
            public HttpsAsyncResult(
                AsyncCallback callback,
                object callbackData,
                int timeout,
                IServiceRequest request,
                HttpResponseMessage response)
                : base(callback, callbackData, timeout)
            {
                this.Request = request;
                this.Response = response;
            }

            /// <summary>
            /// The service request this result belongs to.
            /// </summary>
            public IServiceRequest Request;

            /// <summary>
            /// The http response received for the request, if any.
            /// </summary>
            public HttpResponseMessage Response;
        }

        /// <inheritdoc/>
        /// <remarks>
        /// The request is encoded on the calling thread; the http POST is started with
        /// <see cref="Task.Run(Func{Task})"/> and bounded by <see cref="OperationTimeout"/>.
        /// Errors are not thrown from this method: they are logged, stored in the
        /// <c>Exception</c> property of the returned result and the result is completed.
        /// A response with a non success status code is disposed and not handed out.
        /// </remarks>
        public IAsyncResult BeginSendRequest(
            IServiceRequest request,
            AsyncCallback callback,
            object callbackData)
        {
            try
            {
                var content = new ByteArrayContent(
                    BinaryEncoder.EncodeMessage(request, m_quotas.MessageContext));
                content.Headers.ContentType = s_mediaTypeHeaderValue;

                // the security policy header is only sent for policies other than None
                if (EndpointDescription?.SecurityPolicyUri != null &&
                    !string.Equals(
                        EndpointDescription.SecurityPolicyUri,
                        SecurityPolicies.None,
                        StringComparison.Ordinal))
                {
                    content.Headers.Add(
                        Profiles.HttpsSecurityPolicyHeader,
                        EndpointDescription.SecurityPolicyUri);
                }

                var result = new HttpsAsyncResult(
                    callback,
                    callbackData,
                    OperationTimeout,
                    request,
                    null);

                _ = Task.Run(async () =>
                {
                    HttpResponseMessage response = null;
                    try
                    {
                        using var cts = new CancellationTokenSource(OperationTimeout);
                        response = await m_client.PostAsync(m_url, content, cts.Token)
                            .ConfigureAwait(false);
                        response.EnsureSuccessStatusCode();
                    }
                    catch (Exception ex)
                    {
                        Utils.LogError(ex, "Exception sending HTTPS request.");
                        result.Exception = ex;

                        // a response with a failure status is never passed to the
                        // caller, so it has to be released here
                        response?.Dispose();
                        response = null;
                    }
                    result.Response = response;
                    result.OperationCompleted();
                });

                return result;
            }
            catch (Exception ex)
            {
                // nothing was sent, report the failure through a completed result
                Utils.LogError(ex, "Exception sending HTTPS request.");
                var result = new HttpsAsyncResult(
                    callback,
                    callbackData,
                    OperationTimeout,
                    request,
                    null)
                {
                    Exception = ex
                };
                result.OperationCompleted();
                return result;
            }
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Waits for the request started by <see cref="BeginSendRequest"/>, decodes the
        /// response body and disposes the http response afterwards. Exceptions raised
        /// while waiting or decoding are logged and stored in the result object; they
        /// are not rethrown by this method.
        /// </remarks>
        /// <param name="result">The result returned by <see cref="BeginSendRequest"/>.</param>
        /// <returns>The decoded service response.</returns>
        /// <exception cref="ArgumentException">
        /// The result was not created by this channel.
        /// </exception>
        public IServiceResponse EndSendRequest(IAsyncResult result)
        {
            if (result is not HttpsAsyncResult result2)
            {
                throw new ArgumentException("Invalid result object passed.", nameof(result));
            }

            try
            {
                result2.WaitForComplete();
                if (result2.Response != null)
                {
                    // the message is decoded before this method returns,
                    // so the response can be released when the scope is left
                    using HttpResponseMessage response = result2.Response;
#if NET6_0_OR_GREATER
                    Stream responseContent = response.Content.ReadAsStream();
#else
                    Stream responseContent = response.Content.ReadAsStreamAsync()
                        .GetAwaiter()
                        .GetResult();
#endif
                    return BinaryDecoder.DecodeMessage(
                        responseContent,
                        null,
                        m_quotas.MessageContext) as
                        IServiceResponse;
                }
            }
            catch (Exception ex)
            {
                Utils.LogError(ex, "Exception reading HTTPS response.");
                result2.Exception = ex;
            }

            // NOTE(review): the result class does not declare IServiceResponse, so this
            // presumably yields null when no response could be decoded; confirm that
            // callers expect null here rather than an exception.
            return result2 as IServiceResponse;
        }

        /// <inheritdoc/>
        /// <remarks>Not implemented here; always throws.</remarks>
        public Task<IServiceResponse> EndSendRequestAsync(IAsyncResult result, CancellationToken ct)
            => throw new NotImplementedException();

        /// <inheritdoc/>
        /// <remarks>Not implemented here; always throws.</remarks>
        public IAsyncResult BeginOpen(AsyncCallback callback, object callbackData)
            => throw new NotImplementedException();

        /// <inheritdoc/>
        /// <remarks>Not implemented here; always throws.</remarks>
        public void EndOpen(IAsyncResult result)
            => throw new NotImplementedException();

        /// <inheritdoc/>
        /// <remarks>Only writes a log entry; the http client is left untouched.</remarks>
        public void Reconnect()
            => Utils.LogInfo("HttpsTransportChannel RECONNECT: Reconnecting to {0}.", m_url);

        /// <inheritdoc/>
        /// <remarks>Not implemented here; always throws.</remarks>
        void ITransportChannel.Reconnect(ITransportWaitingConnection connection)
            => throw new NotImplementedException();

        /// <inheritdoc/>
        /// <remarks>Not implemented here; always throws.</remarks>
        public IAsyncResult BeginReconnect(AsyncCallback callback, object callbackData)
            => throw new NotImplementedException();

        /// <inheritdoc/>
        /// <remarks>Not implemented here; always throws.</remarks>
        public void EndReconnect(IAsyncResult result)
            => throw new NotImplementedException();

        /// <inheritdoc/>
        /// <remarks>Not implemented here; always throws.</remarks>
        public IAsyncResult BeginClose(AsyncCallback callback, object callbackData)
            => throw new NotImplementedException();

        /// <inheritdoc/>
        /// <remarks>Not implemented here; always throws.</remarks>
        public void EndClose(IAsyncResult result)
            => throw new NotImplementedException();

        /// <inheritdoc/>
        /// <remarks>
        /// Synchronous wrapper: starts the request without a callback and
        /// blocks in <see cref="EndSendRequest"/> until it has completed.
        /// </remarks>
        public IServiceResponse SendRequest(IServiceRequest request)
            => EndSendRequest(BeginSendRequest(request, null, null));

        /// <inheritdoc/>
        /// <remarks>
        /// Posts the binary encoded request and decodes the response body. The request is
        /// cancelled when <paramref name="ct"/> is cancelled or when
        /// <see cref="OperationTimeout"/> elapses. The http response is disposed once the
        /// message has been decoded or an error occurred.
        /// </remarks>
        /// <param name="request">The service request to send.</param>
        /// <param name="ct">The token used to cancel the request.</param>
        /// <returns>The decoded service response.</returns>
        /// <exception cref="ServiceResultException">
        /// The request was cancelled or timed out, the connection failed,
        /// or any other error which is not a <see cref="HttpRequestException"/> occurred.
        /// </exception>
        /// <exception cref="HttpRequestException">
        /// The request failed and the inner exception is not a <see cref="WebException"/>,
        /// e.g. the server returned a non success status code.
        /// </exception>
        public async Task<IServiceResponse> SendRequestAsync(
            IServiceRequest request,
            CancellationToken ct)
        {
            try
            {
                var content = new ByteArrayContent(
                    BinaryEncoder.EncodeMessage(request, m_quotas.MessageContext));
                content.Headers.ContentType = s_mediaTypeHeaderValue;

                // the security policy header is only sent for policies other than None
                if (EndpointDescription?.SecurityPolicyUri != null &&
                    !string.Equals(
                        EndpointDescription.SecurityPolicyUri,
                        SecurityPolicies.None,
                        StringComparison.Ordinal))
                {
                    content.Headers.Add(
                        Profiles.HttpsSecurityPolicyHeader,
                        EndpointDescription.SecurityPolicyUri);
                }

                HttpResponseMessage response;
                using (var cts = new CancellationTokenSource(OperationTimeout))
                using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(
                    cts.Token,
                    ct))
                {
                    response = await m_client.PostAsync(m_url, content, linkedCts.Token)
                        .ConfigureAwait(false);
                }

                // release the response on every path, including a failure status
                // and decoding errors; the message is decoded inside the scope
                using (response)
                {
                    response.EnsureSuccessStatusCode();

#if NET6_0_OR_GREATER
                    Stream responseContent = await response.Content.ReadAsStreamAsync(ct)
                        .ConfigureAwait(false);
#else
                    Stream responseContent = await response.Content.ReadAsStreamAsync()
                        .ConfigureAwait(false);
#endif
                    return BinaryDecoder.DecodeMessage(
                        responseContent,
                        null,
                        m_quotas.MessageContext) as IServiceResponse;
                }
            }
            catch (HttpRequestException hre)
            {
                // map well known transport errors to OPC UA status codes
                if (hre.InnerException is WebException webex)
                {
                    StatusCode statusCode = StatusCodes.BadUnknownResponse;
                    switch (webex.Status)
                    {
                        case WebExceptionStatus.Timeout:
                            statusCode = StatusCodes.BadRequestTimeout;
                            break;
                        case WebExceptionStatus.ConnectionClosed:
                        case WebExceptionStatus.ConnectFailure:
                            statusCode = StatusCodes.BadNotConnected;
                            break;
                    }
                    Utils.LogError(webex, "Exception sending HTTPS request.");
                    throw ServiceResultException.Create((uint)statusCode, webex.Message);
                }
                Utils.LogError(hre, "Exception sending HTTPS request.");
                throw;
            }
            catch (TaskCanceledException tce)
            {
                // raised for the operation timeout as well as for a cancelled token
                Utils.LogError(tce, "Send request cancelled.");
                throw ServiceResultException.Create(
                    StatusCodes.BadRequestTimeout,
                    "Https request was cancelled.");
            }
            catch (Exception ex)
            {
                Utils.LogError(ex, "Exception sending HTTPS request.");
                throw ServiceResultException.Create(StatusCodes.BadUnknownResponse, ex.Message);
            }
        }

        /// <summary>
        /// Stores the server url and the channel settings and derives the
        /// operation timeout and the channel quotas from them.
        /// </summary>
        /// <param name="url">The server url.</param>
        /// <param name="settings">The settings for the transport channel.</param>
        private void SaveSettings(Uri url, TransportChannelSettings settings)
        {
            string address = url.ToString();
            m_url = new Uri(address);

            // remove the opc. prefix, the https client can not handle it
            if (string.Equals(m_url.Scheme, Utils.UriSchemeOpcHttps, StringComparison.Ordinal))
            {
                m_url = new Uri(address.Substring(4));
            }

            m_settings = settings;
            EndpointConfiguration configuration = settings.Configuration;
            OperationTimeout = configuration.OperationTimeout;

            // initialize the quotas.
            var quotas = new ChannelQuotas();
            quotas.MaxBufferSize = configuration.MaxBufferSize;
            quotas.MaxMessageSize = configuration.MaxMessageSize;
            quotas.ChannelLifetime = configuration.ChannelLifetime;
            quotas.SecurityTokenLifetime = configuration.SecurityTokenLifetime;

            // the message context carries the encoding limits of the endpoint configuration
            var messageContext = new ServiceMessageContext();
            messageContext.MaxArrayLength = configuration.MaxArrayLength;
            messageContext.MaxByteStringLength = configuration.MaxByteStringLength;
            messageContext.MaxMessageSize = configuration.MaxMessageSize;
            messageContext.MaxStringLength = configuration.MaxStringLength;
            messageContext.MaxEncodingNestingLevels = configuration.MaxEncodingNestingLevels;
            messageContext.MaxDecoderRecoveries = configuration.MaxDecoderRecoveries;
            messageContext.NamespaceUris = settings.NamespaceUris;
            messageContext.ServerUris = new StringTable();
            messageContext.Factory = settings.Factory;
            quotas.MessageContext = messageContext;

            quotas.CertificateValidator = settings.CertificateValidator;
            m_quotas = quotas;
        }

        // The url requests are posted to; set by SaveSettings with the opc. prefix removed.
        private Uri m_url;
        // The settings passed to Initialize.
        private TransportChannelSettings m_settings;
        // The quotas, message context and certificate validator built by SaveSettings.
        private ChannelQuotas m_quotas;
        // The http client created by Open and disposed by Close.
        private HttpClient m_client;

        // The content type set on the content of every request.
        private static readonly MediaTypeHeaderValue s_mediaTypeHeaderValue = new(
            "application/octet-stream");
    }
}
